from torch import nn
from hydra.utils import instantiate

# https://github.com/QinbinLi/MOON/blob/main/model.py
class SimpleCNNMOON(nn.Module):
    """Small LeNet-style CNN with a MOON-style projection head.

    Layout: two conv(5x5) -> ReLU -> max-pool(2x2) stages, two fully
    connected layers (flat -> 120 -> 84), a projection head (84 -> 256) whose
    output is cached on ``self.features``, and a linear classifier
    (256 -> ``num_classes``).

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images (default 3).
        input_size: Height/width of the square input images. The default 32
            matches CIFAR/CINIC; pass 28 for 28x28 inputs such as PathMNIST.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two
            conv/pool stages (it must be at least 16).
    """

    def __init__(self, num_classes, in_channels=3, input_size=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Spatial side length after conv(5) -> pool(2) -> conv(5) -> pool(2)
        # (unpadded 5x5 conv shrinks by 4, 2x2/stride-2 pool floors by half):
        # 32 -> 5 for CIFAR/CINIC, 28 -> 4 for PathMNIST.
        spatial = ((input_size - 4) // 2 - 4) // 2
        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small for SimpleCNNMOON; "
                "it must be at least 16"
            )

        # Fixed depth: two hidden fully connected layers.
        self.fc1 = nn.Linear(16 * spatial * spatial, 120)
        self.fc2 = nn.Linear(120, 84)

        # NOTE(review): the last two Linear layers have no activation between
        # them, so together they form a single affine map. Kept as-is so that
        # state_dict keys (projection_head.0/.2/.3) and existing checkpoints
        # remain compatible.
        self.projection_head = nn.Sequential(
            nn.Linear(84, 84),
            nn.ReLU(),
            nn.Linear(84, 84),
            nn.Linear(84, 256),
        )

        self.fc3 = nn.Linear(256, num_classes)

        # Projection-head output of the most recent forward(); None until then.
        self.features = None

    def forward(self, x):
        """Return class logits for ``x`` and cache the projection features.

        Args:
            x: Image tensor, expected as (N, C, H, W) with C == in_channels
               and H == W == input_size.

        Returns:
            Logits of shape (N, num_classes).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten only the (C, H, W) dims of each sample. Unlike
        # view(-1, 400), this never mixes values across samples, so a
        # mismatched input size fails loudly in fc1 instead of silently
        # changing the batch size.
        x = x.flatten(start_dim=-3)

        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))

        # Cached so callers (e.g. a contrastive loss) can read it via get_features().
        self.features = self.projection_head(x)
        return self.fc3(self.features)

    def get_features(self):
        """Return the projection-head output from the last forward(), or None."""
        return self.features